# get concept-wise heldout error
library(tidyverse)
library(network)

source("pundits_functions.R")

# Fix the RNG seed for reproducibility. Nothing visible in this script draws
# random numbers, so this presumably matters for the sourced helpers or for
# interactive reruns -- TODO confirm.
set.seed(1111)

# Mean squared error between observed and predicted values.
#
# obs:  numeric vector of observed values.
# pred: numeric vector of predictions (recycled against `obs` by R's usual
#       rules).
# Returns a single numeric; NA if any squared residual is NA.
get_mse <- function(obs, pred) {
  squared_resid <- (obs - pred)^2
  mean(squared_resid)
}

# Mean absolute error between observed and predicted values.
#
# obs:  numeric vector of observed values.
# pred: numeric vector of predictions (recycled against `obs` by R's usual
#       rules).
# Returns a single numeric; NA if any absolute residual is NA.
get_mae <- function(obs, pred) {
  abs_resid <- abs(obs - pred)
  mean(abs_resid)
}

# Lookup table from the topic key used in file names (`topic`) to its display
# label (`formal_names`), plus an integer `group` code running from 2 to 23.
topic_refdf <- local({
  # Written as key = label pairs so each mapping can be checked at a glance.
  labels_by_topic <- c(
    china = "China",
    class = "Class",
    climate = "Climate",
    conservative = "Conservative",
    democrat = "Democrat",
    far_left = "Far Left",
    far_right = "Far Right",
    gender = "Gender",
    guns = "Guns",
    health_care_insurance = "Health Care/Insurance",
    immigration = "Immigration",
    iran = "Iran",
    israel = "Israel",
    lgbt = "LGBT",
    liberal = "Liberal",
    mueller = "Mueller",
    progressive = "Progressive",
    race = "Race",
    reproductive_health = "Reproductive Health",
    republican = "Republican",
    taxes_spending = "Taxes and Spending",
    trade = "Trade"
  )

  # unname() keeps the keys from leaking into the data frame's row names.
  data.frame(formal_names = unname(labels_by_topic),
             topic = names(labels_by_topic),
             group = seq_along(labels_by_topic) + 1L)
})

# Names of the saved held-out model files: every entry in the output directory
# whose name matches the regular expression "modcv".
files <- dir("../output/heldout_hiernets/", pattern = "modcv")

# Summarise the fit stored in one held-out model file.
#
# Loads the object named `mod.cv` from the .RData file `file` in directory
# `path` and returns a data frame with columns:
#   topic   - held-out concept, i.e. `file` with "modcv_heldout_" and ".RData"
#             stripped
#   lamhat  - `mod.cv$lamhat.1se`
#   cv.err  - `mod.cv$cv.err` at the position of that lambda in `mod.cv$lamlist`
#   nonzero - `mod.cv$nonzero` at the same position
#
# path: directory holding the file; a trailing "/" is optional.
# file: file name, e.g. "modcv_heldout_<topic>.RData".
#
# Stops with an error if the file does not contain `mod.cv`, or if
# `lamhat.1se` is not found in `lamlist`.
getfit <- function(path = "../output/heldout_hiernets/",
                   file){
  # Accept directories with or without a trailing slash; an empty `path` or one
  # already ending in "/" is concatenated exactly as before.
  rdata_path <- if (nzchar(path) && !endsWith(path, "/")) {
    file.path(path, file)
  } else {
    paste0(path, file)
  }

  # Load into a private environment: if the file lacks `mod.cv`, fail loudly
  # instead of silently picking up a same-named object from an enclosing
  # (e.g. global) environment.
  loaded <- new.env(parent = emptyenv())
  load(rdata_path, envir = loaded)
  if (!exists("mod.cv", envir = loaded, inherits = FALSE)) {
    stop("No object named 'mod.cv' found in ", rdata_path, call. = FALSE)
  }
  mod.cv <- get("mod.cv", envir = loaded, inherits = FALSE)

  topic <- gsub("modcv_heldout_|\\.RData", "", file)

  # Position of the selected lambda on the lambda path. Exact comparison is
  # kept on the assumption that `lamhat.1se` is an element copied from
  # `lamlist` -- TODO confirm against the fitting code.
  lam_idx <- which(mod.cv$lamlist == mod.cv$lamhat.1se)
  if (length(lam_idx) == 0) {
    stop("'lamhat.1se' not found in 'lamlist' for ", rdata_path, call. = FALSE)
  }

  data.frame(topic = topic,
             lamhat = mod.cv$lamhat.1se,
             cv.err = mod.cv$cv.err[lam_idx],
             nonzero = mod.cv$nonzero[lam_idx])
}

# Summarise every held-out model file, then replace the file-derived topic key
# with its display label and keep only the columns needed downstream.
fitdf <- files %>%
  map_df(function(heldout_file) getfit(file = heldout_file)) %>%
  left_join(topic_refdf, by = "topic") %>%
  dplyr::select(formal_names, cv.err, nonzero)

# load baseline (all variables) model; this defines `mod.cv` in the current
# environment
load("../output/modcv_allvars.RData")

# Position of the selected lambda on the baseline model's lambda path.
baseline_idx <- which(mod.cv$lamlist == mod.cv$lamhat.1se)

# Append the baseline as a properly typed row labelled "None". Binding a data
# frame (rather than assigning a character vector into a new row) keeps the
# numeric columns numeric instead of round-tripping every value through
# character, which can lose precision.
fitdf <- dplyr::bind_rows(
  fitdf,
  data.frame(formal_names = "None",
             cv.err = as.numeric(mod.cv$cv.err[baseline_idx]),
             nonzero = as.numeric(mod.cv$nonzero[baseline_idx]))
)
fitdf$cv.err <- as.numeric(fitdf$cv.err)
fitdf$nonzero <- as.numeric(fitdf$nonzero)

# make figure 4: one arrow per held-out concept, running from the baseline
# cross-validation error (no features held out) to the error of the model
# fitted with that concept held out.

# Baseline error, looked up once. %in% (unlike ==) never yields NA, so rows
# with a missing `formal_names` cannot leak NA elements into the result.
baseline_cv_err <- fitdf$cv.err[fitdf$formal_names %in% "None"]

heldout_fits <- 
  fitdf %>%
  filter(formal_names != "None") %>%
  arrange(desc(cv.err)) %>%
  # Freeze the sorted order as factor levels; reversed because coord_flip()
  # draws the first level at the bottom.
  mutate(concept = fct_rev(fct_inorder(formal_names))) %>%
  ggplot(aes(x = concept,
             y = baseline_cv_err,
             xend = concept))+
  geom_hline(yintercept = baseline_cv_err,
             lty = "dashed")+
  geom_segment(aes(yend = cv.err), 
               arrow = arrow(length = unit(0.25, "cm")))+
  coord_flip()+
  labs(y = "Cross-Validation Error",
       x = "Held-Out Concept",
       caption = "Arrows begin at baseline fit with no features held out (dashed line)")+
  theme_jg()

# Arguments are named in full; the original `file =` only worked through
# partial matching to `filename`.
ggsave(filename = "../figures/figure_4.png", plot = heldout_fits,
       width = 10, height = 6)

# Same figure as a 300 dpi TIFF via the ragg device.
ragg::agg_tiff("../figures/figure_4.tiff", 
    width=10, height=6, units = "in", 
    res = 300)
print(heldout_fits)
dev.off()

